
[RFC][FEATURE] support manual parallelization strategy in shard parallel #816

Merged: 18 commits merged into main from pr-manual-sharding on Dec 27, 2022

Conversation

@ZYHowell (Collaborator) commented on Dec 18, 2022

Background

This PR adds an option to manually configure the parallelization strategy of ShardParallel. It accepts the same inputs as pjit, but replaces pjit's Mesh with Alpa's LogicalMesh, which does not require physical resources. The API is expected to look like:

class ManualShardingOption:
  in_axis_resources: Optional[PyTree] = pjit._UNSPECIFIED
  out_axis_resources: Optional[PyTree] = pjit._UNSPECIFIED
  mesh_axis_names: Optional[Sequence[str]] = None
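
As a usage sketch (the axis names and the PartitionSpec import path are illustrative, depending on the jax version Alpa targets), a user could request data parallelism over a "data" mesh axis like this:

from jax.experimental import PartitionSpec

option = ManualShardingOption(
    # Chunk the leading (batch) dim of every input over the "data" axis.
    in_axis_resources=PartitionSpec("data"),
    # None means fully replicated, following pjit's convention.
    out_axis_resources=None,
    mesh_axis_names=("data", "model"),
)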

def get_sharding_spec(avals, axis_resources, logical_mesh, mesh_axis_names, option: ManualShardingOption):
  """Translate the axis resources of each aval into the ShardingSpec of each var."""

With the sharding spec of each var, we can then call hlo_module.set_spmd_parameters_shardings or set_spmd_output_sharding to set the sharding specs on the module, and let the SPMD partitioner infer the rest.
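
Putting it together, the intended flow could look roughly like this (the method names come from this PR's text; sharding_spec_to_op_sharding is a hypothetical converter from ShardingSpec to the HLO OpSharding proto):

in_specs = get_sharding_spec(in_avals, option.in_axis_resources, logical_mesh,
                             option.mesh_axis_names, option)
out_specs = get_sharding_spec(out_avals, option.out_axis_resources, logical_mesh,
                              option.mesh_axis_names, option)
hlo_module.set_spmd_parameters_shardings(
    [sharding_spec_to_op_sharding(s) for s in in_specs])
hlo_module.set_spmd_output_sharding(
    sharding_spec_to_op_sharding(out_specs[0]))  # single-output case
# The SPMD partitioner then infers shardings for all intermediate values.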

In addition, some partition strategies are specified via with_sharding_constraint inside the model's execution. To support this, we need another API:

def trace_jaxpr_with_pjit_constraint(fn, in_avals, logical_mesh, mesh_axis_names):
  """Trace a Jaxpr that is aware of with_sharding_constraint."""

Here we need to monkey-patch with_sharding_constraint because, for pipeshard parallel, the sharding spec cannot yet be determined at trace time.
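
One possible shape of that patch, as a sketch only: it records the requested constraints at trace time instead of resolving them against a mesh. Whether Alpa keeps or strips the constraint primitive in the traced Jaxpr is not decided by this snippet:

import contextlib
from jax.experimental import pjit

@contextlib.contextmanager
def _record_sharding_constraints(recorded):
  # Temporarily replace with_sharding_constraint so traced code records
  # its requested axis_resources instead of binding them to a mesh.
  original = pjit.with_sharding_constraint

  def patched(x, axis_resources):
    recorded.append(axis_resources)
    return x  # pass the value through unchanged at trace time

  pjit.with_sharding_constraint = patched
  try:
    yield
  finally:
    pjit.with_sharding_constraint = original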

TODO

  • Support setting input and output sharding spec in ShardParallel;
  • Support partially setting input and output sharding spec in ShardParallel, then use tensorflow/compiler/xla/hlo/experimental/auto_sharding to solve under constraints;
  • Support setting input and output sharding spec in PipeshardParallel, where the mesh shape is also manually specified;
  • Support setting input and output sharding spec in PipeshardParallel, where the mesh shape is determined by automatic stage construction;
  • Support with_sharding_constraint in ShardParallel;
  • Support with_sharding_constraint in PipeshardParallel.

@ZYHowell ZYHowell changed the title [RFC][WIP][FEATURE] support manual config of shard parallel [RFC][WIP][FEATURE] support manual parallelization strategy in shard parallel Dec 18, 2022
@merrymercy merrymercy self-assigned this Dec 23, 2022
@ZYHowell ZYHowell requested a review from merrymercy December 24, 2022 17:55
@ZYHowell ZYHowell changed the title [RFC][WIP][FEATURE] support manual parallelization strategy in shard parallel [RFC][FEATURE] support manual parallelization strategy in shard parallel Dec 24, 2022
@merrymercy (Member) left a comment:

Good job!

Review comment on tests/shard_parallel/test_manual.py (outdated, resolved)
@ZYHowell ZYHowell merged commit 5660516 into main Dec 27, 2022
@ZYHowell ZYHowell deleted the pr-manual-sharding branch December 27, 2022 05:33